#!/usr/bin/env python3
# SPDX-License-Identifier: MIT

from __future__ import annotations

import argparse
import dataclasses
import sys
from collections import defaultdict
from contextlib import AbstractContextManager, nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import TextIO, cast


@dataclass
class Header:
    """
    A section header of a rules ``.part`` file.

    ``text`` is the raw header line as read from the file (possibly empty);
    it is what gets written to the output as the section header. ``key`` is
    the whitespace-split form of the header and decides which files belong
    to the same section, so headers differing only in spacing compare equal.
    """

    # Raw header line, used as the section header in the merged output.
    text: str
    # Grouping key: whitespace-separated tokens of the header, optionally
    # followed by a "+" marker (see set_merge_mode).
    key: tuple[str, ...]

    @classmethod
    def empty(cls) -> Header:
        """Return the header used to group files that have no section header."""
        # Use cls so subclasses get instances of their own type.
        return cls("", ())

    @classmethod
    def parse(cls, raw: str) -> Header:
        """Build a header from the raw header line *raw*."""
        return cls(raw, tuple(raw.split()))

    def set_merge_mode(self, merge_mode: bool) -> Header:
        """
        Return a header whose key records whether explicit merge modes are used.

        When *merge_mode* is true, a copy with ``"+"`` appended to ``key`` is
        returned, so sections with explicit and implicit merge modes get
        distinct keys and are never grouped together. Otherwise ``self`` is
        returned unchanged.
        """
        if not merge_mode:
            return self
        return dataclasses.replace(self, key=self.key + ("+",))


def handle_file(path: Path) -> tuple[Header, Path]:
    """
    Classify the file at *path* by its section header.

    Returns a ``(header, path)`` pair. If the first line (with any ``//``
    comment removed) does not start with ``"! "``, the file has no header
    and ``Header.empty()`` is returned.

    Headers containing ``$`` or ``include`` are returned as parsed. For any
    other header, the rule lines up to the next ``"! "`` line are scanned:
    if the last field of any rule starts with ``+`` or ``|`` (an explicit
    merge mode), the header key is marked accordingly, so that rule sets
    with explicit and implicit merge modes are not mixed.
    """
    with open(path) as fd:
        first_line = fd.readline().split("//")[0]
        if not first_line.startswith("! "):
            return Header.empty(), path

        header = Header.parse(first_line)
        if "$" in first_line or "include" in first_line:
            # Group definitions and includes need no merge-mode analysis.
            return header, path

        # A single rule with an explicit merge mode marks the whole section.
        explicit = False
        for line in fd:
            if line.startswith("! "):
                break
            fields = line.split("//")[0].split()
            if fields and fields[-1].startswith(("+", "|")):
                explicit = True
        return header.set_merge_mode(explicit), path


def merge(dest: TextIO, files: list[Path]) -> None:
    """
    Merge the content of all files into the file dest.

    The first line of each file is an optional section header, e.g.
       ! model =  keycodes
    Files are grouped by header (see handle_file). For each group the header
    is written once, followed by the contents of every file in the group
    with that file's own header line dropped. Where two files have identical
    headers, the second header is therefore skipped.

    Header-less files are a special case: they are grouped under the empty
    header and written out first, without any header line.

    Files are processed in order of their basename. The caller's list is not
    modified.
    """
    # Sort by basename; sorted() leaves the caller's list untouched.
    ordered = sorted(files, key=lambda p: p.name)

    # Group files by header key. The empty header is registered first so
    # header-less files are written out first; section_names keeps the
    # order in which headers were first seen, along with their raw text.
    empty = Header.empty()
    section_names: list[Header] = [empty]
    sections: dict[tuple[str, ...], list[Path]] = {empty.key: []}
    for src in ordered:
        # Files may exist in srcdir or builddir, depending whether they're
        # generated
        header, resolved = handle_file(src)
        if header.key not in sections:
            section_names.append(header)
            sections[header.key] = []
        sections[header.key].append(resolved)

    for header in section_names:
        if header.text:
            dest.write("\n")
            dest.write(header.text)
            # The header text is cut at "//", which also drops its newline
            # when the header line carries a trailing comment (or is the
            # last line of a file without a final newline).
            if not header.text.endswith("\n"):
                dest.write("\n")
        for f in sections[header.key]:
            with open(f) as fd:
                if header.text:
                    fd.readline()  # drop the header
                content = fd.read()
            dest.write(content)
            # Keep the last line of one file from running into the next.
            if content and not content.endswith("\n"):
                dest.write("\n")


if __name__ == "__main__":
    # The first positional argument of ArgumentParser is ``prog``; pass the
    # text as ``description`` so the usage line shows the real script name.
    parser = argparse.ArgumentParser(description="rules file merge script")
    parser.add_argument(
        "--dest", type=Path, default=None, help="output file (default: stdout)"
    )
    parser.add_argument(
        "--srcdir", type=Path, help="source directory to search for input files"
    )
    parser.add_argument(
        "--builddir",
        type=Path,
        help="build directory, searched before --srcdir for input files",
    )
    parser.add_argument(
        "files", nargs="+", type=Path, help="rules .part files to merge"
    )
    ns = parser.parse_args()

    if ns.dest is None:
        dest: AbstractContextManager = nullcontext(sys.stdout)
    else:
        dest = cast(Path, ns.dest).open("wt")

    def find_file(f: Path) -> Path:
        """
        Resolve *f* against --builddir first, then --srcdir.

        Files may live in either directory depending on whether they are
        generated. Falls back to *f* as given when neither contains it.
        """
        for base in (ns.builddir, ns.srcdir):
            if base:
                candidate = cast(Path, base) / f
                if candidate.exists():
                    return candidate
        return f

    with dest as fd:
        basename = Path(sys.argv[0]).name
        fd.write(
            f"// DO NOT EDIT THIS FILE - IT WAS AUTOGENERATED BY {basename} FROM rules/*.part\n"
        )
        fd.write("//\n")
        merge(fd, [find_file(f) for f in ns.files])
